import torch
import torch.nn as nn
from mmcv.cnn import (Linear,build_activation_layer,build_norm_layer,
                      xavier_init)   # notice here

from .builder import TRANSFORMER

class MultiheadAttention(nn.Module):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with a residual connection,
    and the positional encodings used in DETR are also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Same as
            ``nn.MultiheadAttention``.
        dropout (float): Dropout probability. It is forwarded to
            ``nn.MultiheadAttention`` and is also used for the dropout layer
            applied to the attention output before the residual addition.
            Default 0.0.
    """

    def __init__(self, embed_dims, num_heads, dropout=0.0):
        super(MultiheadAttention, self).__init__()
        assert embed_dims % num_heads == 0, (
            'embed_dims must be divisible by num_heads. '
            f'got {embed_dims} and {num_heads}.')
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout)
        # ``self.dropout`` is the module applied in ``forward``; the plain
        # probability stays reachable through ``self.dropout.p``.
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                x,
                key=None,
                value=None,
                residual=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None):
        """Forward function for ``MultiheadAttention``.

        Args:
            x (Tensor): The input query with shape
                [num_query, bs, embed_dims]. Same in
                ``nn.MultiheadAttention.forward``.
            key (Tensor): The key tensor with shape
                [num_key, bs, embed_dims]. Same in
                ``nn.MultiheadAttention.forward``. Default None.
                If None, ``x`` (before adding ``query_pos``) will be used.
            value (Tensor): The value tensor with the same shape as ``key``.
                Same in ``nn.MultiheadAttention.forward``. Default None.
                If None, ``key`` (before adding ``key_pos``) will be used.
            residual (Tensor): The tensor used for addition, with the same
                shape as ``x``. Default None. If None, ``x`` will be used.
            query_pos (Tensor): The positional encoding for the query, with
                the same shape as ``x``. Default None. If not None, it will
                be added to ``x`` before the attention is computed.
            key_pos (Tensor): The positional encoding for ``key``, with the
                same shape as ``key``. Default None. If not None, it will be
                added to ``key`` before the attention is computed. If None
                and ``query_pos`` has the same shape as ``key``, then
                ``query_pos`` will be used for ``key_pos``.
            attn_mask (Tensor): ByteTensor mask with shape
                [num_query, num_key]. Same in
                ``nn.MultiheadAttention.forward``. Default None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
                Same in ``nn.MultiheadAttention.forward``. Default None.

        Returns:
            Tensor: Forwarded results with shape [num_query, bs, embed_dims].
        """
        query = x
        if key is None:
            key = query
        # ``value`` is taken before any positional encoding is added to
        # ``key``, so values never carry positional information.
        if value is None:
            value = key
        if residual is None:
            residual = x
        if key_pos is None and query_pos is not None:
            # Reuse the query encoding for the key only when shapes agree.
            if query_pos.shape == key.shape:
                key_pos = query_pos
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos
        # ``nn.MultiheadAttention`` returns (output, weights); keep output.
        out = self.attn(
            query,
            key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        return residual + self.dropout(out)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims},'
        repr_str += f'num_heads={self.num_heads},'
        # Report the dropout probability rather than the Dropout module.
        repr_str += f'dropout={self.dropout.p})'
        return repr_str

class FFN(nn.Module):
    """Implements feed-forward networks (FFNs) with residual connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            ``MultiheadAttention``.
        feedforward_channels (int): The hidden dimension of FFNs.
        num_fcs (int): The number of fully-connected layers in FFNs.
            Must be at least 2. Default 2.
        act_cfg (dict): The activation config for FFNs.
            Default ``dict(type='ReLU', inplace=True)``.
        dropout (float): Probability of an element to be zeroed.
            Default 0.0.
        add_residual (bool): Whether to add the residual to the output of
            the fully-connected layers. Default True.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 dropout=0.0,
                 add_residual=True):
        super(FFN, self).__init__()
        assert num_fcs >= 2, (
            f'num_fcs should be no less than 2. got {num_fcs}')
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        # A single activation module is built once and shared by every
        # hidden block below.
        self.activate = build_activation_layer(act_cfg)

        layers = []
        in_channels = embed_dims
        # ``num_fcs - 1`` hidden blocks of Linear -> activation -> Dropout,
        # followed by one projection back to ``embed_dims``.
        for _ in range(num_fcs - 1):
            layers.append(
                nn.Sequential(
                    Linear(in_channels, feedforward_channels),
                    self.activate,
                    nn.Dropout(dropout)))
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        self.layers = nn.Sequential(*layers)
        # ``self.dropout`` is the module applied in ``forward``; the plain
        # probability stays reachable through ``self.dropout.p``.
        self.dropout = nn.Dropout(dropout)
        self.add_residual = add_residual

    def forward(self, x, residual=None):
        """Forward function for ``FFN``.

        Args:
            x (Tensor): The input passed through the fully-connected layers.
            residual (Tensor): The tensor added to the (dropped-out) output
                when ``add_residual`` is True. Default None. If None, ``x``
                will be used.

        Returns:
            Tensor: The raw output of the layers when ``add_residual`` is
                False, otherwise ``residual + dropout(output)``.
        """
        out = self.layers(x)
        if not self.add_residual:
            return out
        if residual is None:
            residual = x
        return residual + self.dropout(out)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims},'
        repr_str += f'feedforward_channels={self.feedforward_channels},'
        repr_str += f'num_fcs={self.num_fcs},'
        repr_str += f'act_cfg={self.act_cfg},'
        # Report the dropout probability rather than the Dropout module.
        repr_str += f'dropout={self.dropout.p},'
        repr_str += f'add_residual={self.add_residual})'
        return repr_str

class TransformerEncoderLayer(nn.Module):
    """Implements one encoder layer in DETR transformer.

    Args:
        embed_dims (int): The feature dimension. Same as ``FFN``.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        dropout (float): Probability of an element to be zeroed.
            Default 0.0.
        order (tuple[str]): The order for encoder layer. Valid examples are
            ('selfattn', 'norm', 'ffn', 'norm') and
            ('norm', 'selfattn', 'norm', 'ffn').
            Default ('selfattn', 'norm', 'ffn', 'norm').
        act_cfg (dict): The activation config for FFNs. Default ReLU.
        norm_cfg (dict): Config dict for normalization layer.
            Default layer normalization.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default 2.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'ffn', 'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2):
        super(TransformerEncoderLayer, self).__init__()
        assert isinstance(order, tuple) and len(order) == 4
        assert set(order) == set(['selfattn', 'norm', 'ffn'])
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        # Stored because ``__repr__`` reports it.
        self.num_fcs = num_fcs
        # The layer is "pre-norm" when normalization comes first in order.
        self.pre_norm = order[0] == 'norm'
        self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout)
        self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
                       dropout)
        # Two independent norm layers, consumed in the sequence in which
        # 'norm' appears in ``order``.
        self.norms = nn.ModuleList()
        self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1])
        self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1])

    def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None):
        """Forward function for ``TransformerEncoderLayer``.

        Args:
            x (Tensor): The input query with shape
                [num_key, bs, embed_dims]. Same in
                ``MultiheadAttention.forward``.
            pos (Tensor): The positional encoding for query. Default None.
                Same as ``query_pos`` in ``MultiheadAttention.forward``.
            attn_mask (Tensor): ByteTensor mask with shape
                [num_key, num_key]. Same in
                ``MultiheadAttention.forward``. Default None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
                Same in ``MultiheadAttention.forward``. Default None.

        Returns:
            Tensor: Forwarded results with shape [num_key, bs, embed_dims].
        """
        norm_cnt = 0
        inp_residual = x
        for layer in self.order:
            if layer == 'selfattn':
                # Self attention: query, key and value are all ``x``.
                query = key = value = x
                # In pre-norm mode the residual is the un-normalized input;
                # otherwise None lets the attention module fall back to its
                # own query as the residual.
                x = self.self_attn(
                    query,
                    key,
                    value,
                    inp_residual if self.pre_norm else None,
                    query_pos=pos,
                    attn_mask=attn_mask,
                    key_padding_mask=key_padding_mask)
                inp_residual = x
            elif layer == 'norm':
                x = self.norms[norm_cnt](x)
                norm_cnt += 1
            elif layer == 'ffn':
                x = self.ffn(x, inp_residual if self.pre_norm else None)
        return x

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(embed_dims={self.embed_dims},'
        repr_str += f'num_heads={self.num_heads},'
        repr_str += f'feedforward_channels={self.feedforward_channels},'
        repr_str += f'dropout={self.dropout},'
        repr_str += f'order={self.order},'
        repr_str += f'act_cfg={self.act_cfg},'
        repr_str += f'norm_cfg={self.norm_cfg},'
        repr_str += f'num_fcs={self.num_fcs})'
        return repr_str

class TransformerDecoderLayer(nn.Module):
    """One decoder layer of the DETR transformer.

    The layer chains self-attention, cross-attention over the encoder
    memory, a feed-forward network and three normalization layers, in the
    sequence given by ``order``.

    Args:
        embed_dims (int): The feature dimension. Same as
            ``TransformerEncoderLayer``.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): Same as ``TransformerEncoderLayer``.
        dropout (float): Same as ``TransformerEncoderLayer``. Default 0.0.
        order (tuple[str]): The order for decoder layer. Valid examples are
            ('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', 'norm') and
            ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn').
            Default the former.
        act_cfg (dict): The activation config for FFNs. Default ReLU.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        num_fcs (int): The number of fully-connected layers in FFNs.
            Default 2.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn',
                        'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2):
        super(TransformerDecoderLayer, self).__init__()
        assert isinstance(order, tuple) and len(order) == 6
        assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn'])

        # Plain configuration, kept for ``__repr__``.
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        # Pre-norm layout: normalization is the first step of ``order``.
        self.pre_norm = order[0] == 'norm'

        # Sub-modules.
        self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout)
        self.multihead_attn = MultiheadAttention(embed_dims, num_heads,
                                                 dropout)
        self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
                       dropout)
        # Three norm layers, used in the sequence 'norm' occurs in ``order``.
        self.norms = nn.ModuleList(
            [build_norm_layer(norm_cfg, embed_dims)[1] for _ in range(3)])

    def forward(self,
                x,
                memory,
                memory_pos=None,
                query_pos=None,
                memory_attn_mask=None,
                target_attn_mask=None,
                memory_key_padding_mask=None,
                target_key_padding_mask=None):
        """Forward function for ``TransformerDecoderLayer``.

        Args:
            x (Tensor): Input query. Passed as ``x`` to
                ``MultiheadAttention.forward``.
            memory (Tensor): Tensor used as key and value of the
                cross-attention (the encoder output in ``Transformer``).
            memory_pos (Tensor): The positional encoding for ``memory``.
                Default None. Same as ``key_pos`` in
                ``MultiheadAttention.forward``.
            query_pos (Tensor): The positional encoding for the query.
                Default None. Same as ``query_pos`` in
                ``MultiheadAttention.forward``.
            memory_attn_mask (Tensor): ``attn_mask`` of the cross-attention.
                Default None.
            target_attn_mask (Tensor): ``attn_mask`` of the self-attention.
                Default None.
            memory_key_padding_mask (Tensor): ``key_padding_mask`` of the
                cross-attention. Default None.
            target_key_padding_mask (Tensor): ``key_padding_mask`` of the
                self-attention. Default None.

        Returns:
            Tensor: Forwarded results, with the same shape as ``x``.
        """
        norm_idx = 0
        skip = x
        for step in self.order:
            # Only the pre-norm layout hands an explicit residual to the
            # sub-module; otherwise each sub-module uses its own input.
            residual = skip if self.pre_norm else None
            if step == 'norm':
                x = self.norms[norm_idx](x)
                norm_idx += 1
            elif step == 'ffn':
                x = self.ffn(x, residual)
            elif step == 'selfattn':
                # Query, key and value are all the current activations.
                x = self.self_attn(
                    x,
                    x,
                    x,
                    residual,
                    query_pos,
                    key_pos=query_pos,
                    attn_mask=target_attn_mask,
                    key_padding_mask=target_key_padding_mask)
                skip = x
            elif step == 'multiheadattn':
                # Cross-attention: keys and values come from ``memory``.
                x = self.multihead_attn(
                    x,
                    memory,
                    memory,
                    residual,
                    query_pos,
                    key_pos=memory_pos,
                    attn_mask=memory_attn_mask,
                    key_padding_mask=memory_key_padding_mask)
                skip = x
        return x

    def __repr__(self):
        """str: a string that describes the module"""
        fields = (
            f'(embed_dims={self.embed_dims}, ',
            f'num_heads={self.num_heads}, ',
            f'feedforward_channels={self.feedforward_channels}, ',
            f'dropout={self.dropout}, ',
            f'order={self.order}, ',
            f'act_cfg={self.act_cfg}, ',
            f'norm_cfg={self.norm_cfg}, ',
            f'num_fcs={self.num_fcs})',
        )
        return self.__class__.__name__ + ''.join(fields)

class TransformerEncoder(nn.Module):
    """Implements the encoder in DETR transformer.

    Args:
        num_layers (int): The number of `TransformerEncoderLayer`.
        embed_dims (int): Same as `TransformerEncoderLayer`.
        num_heads (int): Same as `TransformerEncoderLayer`.
        feedforward_channels (int): Same as `TransformerEncoderLayer`.
        dropout (float): Same as `TransformerEncoderLayer`. Default 0.0.
        order (tuple[str]): Same as `TransformerEncoderLayer`.
        act_cfg (dict): Same as `TransformerEncoderLayer`. Default ReLU.
        norm_cfg (dict): Same as `TransformerEncoderLayer`. Default
            layer normalization.
        num_fcs (int): Same as `TransformerEncoderLayer`. Default 2.
    """

    def __init__(self,
                 num_layers,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'ffn', 'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2):
        super(TransformerEncoder, self).__init__()
        assert isinstance(order, tuple) and len(order) == 4
        assert set(order) == set(['selfattn', 'norm', 'ffn'])
        self.num_layers = num_layers
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.pre_norm = order[0] == 'norm'
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(embed_dims, num_heads,
                                        feedforward_channels, dropout, order,
                                        act_cfg, norm_cfg, num_fcs))
        # A final norm is only built for the pre-norm layout, where no layer
        # ends with a normalization step.
        self.norm = build_norm_layer(norm_cfg,
                                     embed_dims)[1] if self.pre_norm else None

    def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None):
        """Forward function for `TransformerEncoder`.

        Args:
            x (Tensor): Input query. Same in `TransformerEncoderLayer.forward`.
            pos (Tensor): Positional encoding for query. Default None.
                Same in `TransformerEncoderLayer.forward`.
            attn_mask (Tensor): ByteTensor attention mask. Default None.
                Same in `TransformerEncoderLayer.forward`.
            key_padding_mask (Tensor): Same in
                `TransformerEncoderLayer.forward`. Default None.

        Returns:
            Tensor: Results with shape [num_key, bs, embed_dims].
        """
        for layer in self.layers:
            x = layer(x, pos, attn_mask, key_padding_mask)
        if self.norm is not None:
            x = self.norm(x)
        return x

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(num_layers={self.num_layers}, '
        repr_str += f'embed_dims={self.embed_dims}, '
        repr_str += f'num_heads={self.num_heads}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'order={self.order}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'norm_cfg={self.norm_cfg}, '
        repr_str += f'num_fcs={self.num_fcs})'
        return repr_str

class TransformerDecoder(nn.Module):
    """Implements the decoder in DETR transformer.

    Args:
        num_layers (int): The number of `TransformerDecoderLayer`.
        embed_dims (int): Same as `TransformerDecoderLayer`.
        num_heads (int): Same as `TransformerDecoderLayer`.
        feedforward_channels (int): Same as `TransformerDecoderLayer`.
        dropout (float): Same as `TransformerDecoderLayer`. Default 0.0.
        order (tuple[str]): Same as `TransformerDecoderLayer`.
        act_cfg (dict): Same as `TransformerDecoderLayer`. Default ReLU.
        norm_cfg (dict): Same as `TransformerDecoderLayer`. Default
            layer normalization.
        num_fcs (int): Same as `TransformerDecoderLayer`. Default 2.
        return_intermediate (bool): Whether to return the normalized output
            of every decoder layer instead of only the last one.
            Default False.
    """

    def __init__(self,
                 num_layers,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 dropout=0.0,
                 order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn',
                        'norm'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2,
                 return_intermediate=False):
        super(TransformerDecoder, self).__init__()
        assert isinstance(order, tuple) and len(order) == 6
        assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn'])
        self.num_layers = num_layers
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.order = order
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.return_intermediate = return_intermediate
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerDecoderLayer(embed_dims, num_heads,
                                        feedforward_channels, dropout, order,
                                        act_cfg, norm_cfg, num_fcs))
        # One shared output norm, built exactly once (outside the loop) so
        # it exists even when ``num_layers`` is 0.
        self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self,
                x,
                memory,
                memory_pos=None,
                query_pos=None,
                memory_attn_mask=None,
                target_attn_mask=None,
                memory_key_padding_mask=None,
                target_key_padding_mask=None):
        """Forward function for `TransformerDecoder`.

        Args:
            x (Tensor): Input query. Same in `TransformerDecoderLayer.forward`.
            memory (Tensor): Same in `TransformerDecoderLayer.forward`.
            memory_pos (Tensor): Same in `TransformerDecoderLayer.forward`.
                Default None.
            query_pos (Tensor): Same in `TransformerDecoderLayer.forward`.
                Default None.
            memory_attn_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            target_attn_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            memory_key_padding_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.
            target_key_padding_mask (Tensor): Same in
                `TransformerDecoderLayer.forward`. Default None.

        Returns:
            Tensor: The normalized decoder output with one extra leading
                dimension. If ``return_intermediate`` is True, the
                normalized outputs of all layers are stacked along that
                dimension; otherwise it has size 1 and holds only the final
                output.
        """
        intermediate = []
        for layer in self.layers:
            x = layer(x, memory, memory_pos, query_pos, memory_attn_mask,
                      target_attn_mask, memory_key_padding_mask,
                      target_key_padding_mask)
            if self.return_intermediate:
                intermediate.append(self.norm(x))
        if self.norm is not None:
            x = self.norm(x)
            if self.return_intermediate:
                # Replace the last entry with the final normalized output;
                # the list is empty only when there are no layers.
                if intermediate:
                    intermediate.pop()
                intermediate.append(x)
        if self.return_intermediate:
            return torch.stack(intermediate)
        return x.unsqueeze(0)

    def __repr__(self):
        """str: a string that describes the module"""
        repr_str = self.__class__.__name__
        repr_str += f'(num_layers={self.num_layers}, '
        repr_str += f'embed_dims={self.embed_dims}, '
        repr_str += f'num_heads={self.num_heads}, '
        repr_str += f'feedforward_channels={self.feedforward_channels}, '
        repr_str += f'dropout={self.dropout}, '
        repr_str += f'order={self.order}, '
        repr_str += f'act_cfg={self.act_cfg}, '
        repr_str += f'norm_cfg={self.norm_cfg}, '
        repr_str += f'num_fcs={self.num_fcs}, '
        repr_str += f'return_intermediate={self.return_intermediate})'
        return repr_str


@TRANSFORMER.register_module()
class Transformer(nn.Module):
    """Implements the DETR transformer.

    Following the official DETR implementation, this module is adapted
    from torch.nn.Transformer with modifications:

        * positional encodings are passed in MultiheadAttention
        * extra LN at the end of encoder is removed
        * decoder returns a stack of activations from all decoding layers

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        num_encoder_layers (int): Number of `TransformerEncoderLayer`.
        num_decoder_layers (int): Number of `TransformerDecoderLayer`.
        feedforward_channels (int): The hidden dimension for FFNs used in
            both encoder and decoder.
        dropout (float): Probability of an element to be zeroed.
            Default 0.0.
        act_cfg (dict): Activation config for FFNs used in both encoder
            and decoder. Default ReLU.
        norm_cfg (dict): Config dict for normalization used in both encoder
            and decoder. Default layer normalization.
        num_fcs (int): The number of fully-connected layers in FFNs, which
            is used for both encoder and decoder.
        pre_norm (bool): Whether the normalization layer is ordered
            first in the encoder and decoder. Default False.
        return_intermediate_dec (bool): Whether to return the intermediate
            output from each TransformerDecoderLayer or only the last
            TransformerDecoderLayer. Default False. If True, the returned
            `hs` has shape [num_decoder_layers, bs, num_query, embed_dims].
            If False, the returned `hs` will have shape [1, bs, num_query,
            embed_dims].
    """

    def __init__(self,
                 embed_dims=512,
                 num_heads=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 feedforward_channels=2048,
                 dropout=0.0,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 num_fcs=2,
                 pre_norm=False,
                 return_intermediate_dec=False):
        super(Transformer, self).__init__()
        # Plain configuration, kept for ``__repr__``.
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.feedforward_channels = feedforward_channels
        self.dropout = dropout
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.num_fcs = num_fcs
        self.pre_norm = pre_norm
        self.return_intermediate_dec = return_intermediate_dec

        # Pre-norm puts every 'norm' step before the sub-layer it belongs
        # to; post-norm puts it after.
        if not self.pre_norm:
            enc_order = ('selfattn', 'norm', 'ffn', 'norm')
            dec_order = ('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn',
                         'norm')
        else:
            enc_order = ('norm', 'selfattn', 'norm', 'ffn')
            dec_order = ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm',
                         'ffn')

        self.encoder = TransformerEncoder(
            num_encoder_layers, embed_dims, num_heads, feedforward_channels,
            dropout, enc_order, act_cfg, norm_cfg, num_fcs)
        self.decoder = TransformerDecoder(
            num_decoder_layers, embed_dims, num_heads, feedforward_channels,
            dropout, dec_order, act_cfg, norm_cfg, num_fcs,
            return_intermediate_dec)

    def init_weights(self, distribution='uniform'):
        """Initialize the transformer weights.

        Args:
            distribution (str): Passed through to ``xavier_init``.
                Default 'uniform'.
        """
        # Follow the official DETR to init parameters: only sub-modules
        # exposing a ``weight`` with more than one dimension are touched.
        for module in self.modules():
            if not hasattr(module, 'weight'):
                continue
            if module.weight.dim() > 1:
                xavier_init(module, distribution=distribution)

    def forward(self, x, mask, query_embed, pos_embed):
        """Forward function for `Transformer`.

        Args:
            x (Tensor): Input query with shape [bs, c, h, w] where
                c = embed_dims.
            mask (Tensor): The key_padding_mask used for encoder and decoder,
                with shape [bs, h, w].
            query_embed (Tensor): The query embedding for decoder, with shape
                [num_query, c].
            pos_embed (Tensor): The positional encoding for encoder and
                decoder, with the same shape as `x`.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - out_dec: Output from decoder. If return_intermediate_dec \
                      is True output has shape [num_dec_layers, bs,
                      num_query, embed_dims], else has shape [1, bs, \
                      num_query, embed_dims].
                - memory: Output results from encoder, with shape \
                      [bs, embed_dims, h, w].
        """
        batch, channels, height, width = x.shape

        # Flatten the spatial grid into a sequence:
        # [bs, c, h, w] -> [h*w, bs, c]
        seq = x.flatten(2).permute(2, 0, 1)
        pos_seq = pos_embed.flatten(2).permute(2, 0, 1)
        # Give every sample its own copy of the queries:
        # [num_query, dim] -> [num_query, bs, dim]
        queries = query_embed.unsqueeze(1).repeat(1, batch, 1)
        # [bs, h, w] -> [bs, h*w]
        padding_mask = mask.flatten(1)

        memory = self.encoder(
            seq, pos=pos_seq, attn_mask=None, key_padding_mask=padding_mask)

        # The decoder input starts from zeros; the query embedding enters
        # only as its positional encoding.
        target = torch.zeros_like(queries)
        # decoded: [num_layers, num_query, bs, dim]
        decoded = self.decoder(
            target,
            memory,
            memory_pos=pos_seq,
            query_pos=queries,
            memory_attn_mask=None,
            target_attn_mask=None,
            memory_key_padding_mask=padding_mask,
            target_key_padding_mask=None)

        # -> [num_layers, bs, num_query, dim]
        decoded = decoded.transpose(1, 2)
        # Restore the encoder output to the layout of the input feature map.
        memory = memory.permute(1, 2, 0).reshape(batch, channels, height,
                                                 width)
        return decoded, memory

    def __repr__(self):
        """str: a string that describes the module"""
        fields = (
            f'(embed_dims={self.embed_dims}, ',
            f'num_heads={self.num_heads}, ',
            f'num_encoder_layers={self.num_encoder_layers}, ',
            f'num_decoder_layers={self.num_decoder_layers}, ',
            f'feedforward_channels={self.feedforward_channels}, ',
            f'dropout={self.dropout}, ',
            f'act_cfg={self.act_cfg}, ',
            f'norm_cfg={self.norm_cfg}, ',
            f'num_fcs={self.num_fcs}, ',
            f'pre_norm={self.pre_norm}, ',
            f'return_intermediate_dec={self.return_intermediate_dec})',
        )
        return self.__class__.__name__ + ''.join(fields)